{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ONNX Script 教程\n",
    "\n",
    "在本教程中，我们通过示例展示了 ONNX Script 支持的特性。\n",
    "\n",
    "## ONNX Script 基础特性\n",
    "\n",
    "下面的示例展示了作为 ONNX Script 函数的定义。`Softplus`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用 ONNX opset 15 来定义下面的函数\n",
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "\n",
    "# 用 `@script` 装饰器来表明下面的函数旨在被翻译成 ONNX。\n",
    "@script()\n",
    "def Softplus(X):\n",
    "    return op.Log(op.Exp(X) + 1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在上述示例中，诸如 `op.Log(...)` 和 `op.Exp(...)` 这样的表达式表示调用 ONNX 算子（并被翻译成 ONNX NodeProto）。这里，`op` 用于标识包含所调用算子的 `opset`。在这个例子中，使用的是标准 ONNX opset 版本 15（由导入语句 `from onnxscript.onnx_opset import opset15 as op` 确定）。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "算子如 `Add` 被支持作为 `·` 简写，并映射到相应标准 ONNX 算子（例如 `Add`）的适当 opset 中。在上述例子中，使用 `op.Add` 表明要使用 opset 15。如果示例没有以这种方式显式使用 opset，则必须通过调用 `@script()` 装饰器的 `default_opset` 参数来指定。\n",
    "\n",
    "同样地，常量字面值如 `1.0` 也被允许作为句法简写（如上述示例中的上下文），并且会隐式提升为 ONNX 张量常量。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX Script 省略可选输入\n",
    "ONNX 算子的一些输入参数是可选的：例如，`Clip` 操作的 `min` 和 `max` 输入。可以使用 `None` 值来表示省略的可选输入，如下所示，或者在尾部输入的情况下可以简单地省略它们：`Clip(a, None, None)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "\n",
    "@script()\n",
    "def omitted_input(x):\n",
    "    # The following two statements are equivalent:\n",
    "    y1 = op.Clip(x)\n",
    "    y2 = op.Clip(x, None, None)\n",
    "    # The following example shows an omitted optional input, followed by another input\n",
    "    y3 = op.Clip(x, None, 1.0)\n",
    "    return y1 + y2 + y3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX Script 指定属性参数值\n",
    "下面的示例展示了如何在调用中指定属性值。在这个例子中，调用了 ONNX 算子，并为 `shape` 属性的 `start` 和 `end` 指定了属性值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "\n",
    "@script()\n",
    "def FirstDim(X):\n",
    "    return op.Shape(X, start=0, end=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在将调用翻译为 ONNX 算子时，翻译器利用算子的规范来将实际参数映射到适当的输入参数和属性参数。由于 ONNX 规范没有指示属性参数的任何顺序，建议使用关键字参数（又名命名参数）来指定属性参数。OpSchema\n",
    "\n",
    "如果翻译器没有所调用算子的 opschema，它会使用以下策略将实际参数映射到适当的输入参数和属性参数：Python 的关键字参数被翻译为 ONNX 的属性参数，而位置参数被翻译为普通的值参数。因此，在上面的例子中，`X` 被视为这个特定调用的普通值参数，而 `start` 和 `end` 被视为属性参数（当 opschema 不可用时）。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX Script 指定张量常量\n",
    "可以使用 ONNX 工具创建张量常量，并且这些常量可以用作属性值，如下所示。此外，它们可以被提升以使用 ONNX 算子作为张量值，同样如下所示。make_tensorConstant"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnx import TensorProto, helper\n",
    "\n",
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "\n",
    "@script()\n",
    "def tensor_attr(x):\n",
    "    c = op.Constant(value=helper.make_tensor(\"scalar_half\", TensorProto.FLOAT, (), [0.5]))\n",
    "    return op.Mul(c, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上面显示的代码虽然冗长，但允许用户明确指定他们想要的内容。转换器作为一种便利，允许用户使用数字常量，如下例所示，这被翻译成与上述相同的 ONNX 表示形式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "tensor_attr: Already defined.\n"
     ]
    }
   ],
   "source": [
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "\n",
    "@script()\n",
    "def tensor_attr(x):\n",
    "    return op.Mul(0.5, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "直接使用字面量可以创建标量或类型为 `FLOAT`、`INT64` 或 `STRING` 的一维张量，如下表所示。\n",
    "\n",
    "\n",
    "Python source|Generated ONNX constant\n",
    ":-|:-\n",
    "`0`|Scalar value of type `0` `INT64`\n",
    "`0.0`|Scalar value of type `0.0` `FLOAT`\n",
    "`\"x\"`|Scalar value of type `\"x\"` `STRING`\n",
    "`[0, 1]`|One dimensional tensor of type `INT64`\n",
    "`[0.0, 1.0]`|One dimensional tensor of type `FLOAT`\n",
    "`[\"x\", \"y\"]`|One dimensional tensor of type `STRING`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然而，如果用户想要使用其他类型或其它秩（rank）的张量常量，他们需要更明确地这样做（如前例所示）。\n",
    "\n",
    "## ONNX Script 语义：脚本常量\n",
    "在 ONNX 中，属性被要求为常数值。在 ONNX Script 中，指定为属性的表达式在脚本时间（当脚本装饰器被评估时）在定义脚本函数的上下文中进行求值。只要它具有有效的类型，结果 Python 值就被转换为一个 ONNX 属性。\n",
    "\n",
    "这有几个重要的语义含义。首先，它允许在期望属性值的上下文中使用任意 Python 代码。然而，Python 代码必须能够使用定义脚本函数的全局上下文进行求值。例如，不允许使用函数本身的参数（即使是属性参数）进行计算。\n",
    "\n",
    "ONNX Script 假设这样的 Python 代码代表常量。如果在表达式中使用的变量值随后被修改，这种修改对属性值或创建的 ONNX 函数/模型没有影响。这可能会导致急切模式（eager-mode）执行的行为与生成的 ONNX 构造不一致。\n",
    "\n",
    "因此，上面显示的示例等价于下面的内容："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "tensor_attr: Already defined.\n"
     ]
    }
   ],
   "source": [
    "from onnx import TensorProto, helper\n",
    "\n",
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "script_const = helper.make_tensor(\"scalar_half\", TensorProto.FLOAT, (), [0.5])\n",
    "\n",
    "\n",
    "@script()\n",
    "def tensor_attr(x):\n",
    "    c = op.Constant(value=script_const)\n",
    "    return c * x\n",
    "\n",
    "\n",
    "# The following assignment has no effect on the ONNX FunctionProto\n",
    "# generated from tensor_attr:\n",
    "\n",
    "\n",
    "script_const = helper.make_tensor(\"scalar_one\", TensorProto.FLOAT, (), [1.0])\n",
    "\n",
    "fp = tensor_attr.to_function_proto()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX Script 指定函数的形式属性参数\n",
    "Python 函数的（形式）输入参数被转换器视为代表属性参数或生成的 ONNX 函数的输入值参数。然而，转换器需要知道每个参数是表示属性还是输入。转换器使用形式输入参数上的类型注解来做出这种区分。因此，在下面的例子中，`alpha` 被视为一个属性参数（因为它的类型注解）。alphafloat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "\n",
    "@script()\n",
    "def LeakyRelu(X, alpha: float):\n",
    "    return op.Where(X < 0.0, alpha * X, X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "支持的属性（ONNX）类型及其对应的（Python）类型注解在下表中显示。其他类型的 ONNX 属性尚未支持。\n",
    "\n",
    "ONNX Type|Python Type Annotation\n",
    ":-|:-\n",
    "AttributeProto.FLOAT|float\n",
    "AttributeProto.INT|int, bool\n",
    "AttributeProto.STRING|str\n",
    "AttributeProto.FLOATS|Sequence[float]\n",
    "AttributeProto.INTS|Sequence[int]\n",
    "AttributeProto.STRINGS|Sequence[str]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX Script 属性参数自动提升为值\n",
    "如上述示例所示，当在需要值参数的上下文中使用属性参数时，转换器将自动将属性转换为张量值。具体来说，在子表达式 `alpha * X` 中，属性参数 `alpha` 被用作调用 op（由 `op.Mul` 表示）的值参数，并自动进行转换。因此，`alpha * X` 被自动转换为 `alphaMul`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LeakyRelu: Already defined.\n"
     ]
    }
   ],
   "source": [
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "\n",
    "@script()\n",
    "def LeakyRelu(X, alpha: float):\n",
    "    return op.Where(X < 0.0, alpha * X, X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "扩展为以下内容："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LeakyRelu: Already defined.\n"
     ]
    }
   ],
   "source": [
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "\n",
    "@script()\n",
    "def LeakyRelu(X, alpha: float):\n",
    "    alpha_value = op.Constant(value_float=alpha)\n",
    "    return op.Where(X < 0.0, alpha_value * X, X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX Script 常量值的自动类型转换\n",
    "当常量在被约束为与某些其他（非常量）操作数相同类型的上下文中使用，转换器也会自动引入类型转换（通过 ONNX 的 `op.CastLike` 操作）。例如，表达式 `2 * X` 被扩展为 `op.CastLike(2, X) * X`，这允许相同的代码适用于不同类型的 `X`。\n",
    "\n",
    "## ONNX Script 索引和切片\n",
    "ONNX Script 支持在张量上使用 Python 的索引和切片操作，这些操作被转换为 ONNX 的 `Slice` 和 `Gather` 操作。这个操作的语义类似于 Numpy 的。\n",
    "\n",
    "在表达式 `e[i_1, i_2, ..., i_n]` 中，`n` 是输入张量的秩或者是小于该值的任何值。每个索引值可以是标量值（秩为零的张量）或更高维的张量，或者是形式为 `start:end:step` 的切片表达式。从语义上讲，切片表达式等价于包含相应值序列的一维张量。\n",
    "\n",
    "然而，转换器将使用切片表达式的索引映射到可能比相应的 `Gather` 操作更高效的 ONNX 的 `Slice` 操作。更一般的情况（其中 `i_j` 是任意张量）使用 `Gather` 操作进行转换。\n",
    "\n",
    "注意：当前实现尚不支持在索引表达式中使用任意张量。它不支持在索引中使用省略号或新轴。\n",
    "\n",
    "## ONNX Script 控制流\n",
    "在 ONNX Script 中对控制流构造的支持受到 ONNX 控制流操作的限制。\n",
    "\n",
    "### ONNX Script 条件语句\n",
    "\n",
    "下面的函数定义示例说明了条件语句的使用。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "\n",
    "@script()\n",
    "def Dropout2(data, ratio, training_mode, seed: float):\n",
    "    if training_mode:\n",
    "        rand = op.RandomUniformLike(data, dtype=1, seed=seed)\n",
    "        mask = rand >= ratio\n",
    "        output = op.Where(mask, data, 0) / (1.0 - ratio)\n",
    "    else:\n",
    "        mask = op.ConstantOfShape(op.Shape(data), value=True)\n",
    "        output = data\n",
    "    return (output, mask)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "条件语句的使用要求在代码中使用的任何变量在所有可能的路径到使用处都有相同变量的定义。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ONNX Script 循环\n",
    "ONNX 实现了循环算子，执行固定次数的迭代和/或在条件不再为真时跳出循环。下面的第一个示例说明了最简单的情况：固定次数的迭代的使用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "\n",
    "@script()\n",
    "def sumprod(x, N):\n",
    "    sum = op.Identity(x)\n",
    "    prod = op.Identity(x)\n",
    "    for _ in range(N):\n",
    "        sum = sum + x\n",
    "        prod = prod * x\n",
    "    return sum, prod"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "第二个示例展示了如果条件不再为真时循环的中断。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "sumprod: Already defined.\n"
     ]
    }
   ],
   "source": [
    "from onnx import TensorProto\n",
    "from onnx.helper import make_tensor\n",
    "\n",
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "\n",
    "@script()\n",
    "def sumprod(x, N):\n",
    "    sum = op.Identity(x)\n",
    "    prod = op.Identity(x)\n",
    "    cond = op.Constant(value=make_tensor(\"true\", TensorProto.BOOL, [1], [1]))\n",
    "    i = op.Constant(value=make_tensor(\"i\", TensorProto.INT64, [1], [0]))\n",
    "    while cond:\n",
    "        sum = sum + x\n",
    "        prod = prod * x\n",
    "        i = i + 1\n",
    "        cond = i < 10\n",
    "    return sum, prod"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "第三个示例混合了这两种类型的循环。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script\n",
    "\n",
    "\n",
    "@script()\n",
    "def sumprod_break(x, N):\n",
    "    sum = op.Identity(x)\n",
    "    prod = op.Identity(x)\n",
    "    for _ in range(N):\n",
    "        sum = sum + x\n",
    "        prod = prod * x\n",
    "        cond = op.ReduceSum(prod) > 1e7\n",
    "        # ONNX does not support break instruction.\n",
    "        # onnxscript can only convert example if the break\n",
    "        # instruction is placed at the end of the loop body.\n",
    "        if cond:\n",
    "            break\n",
    "    return sum, prod"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX Script 编码高阶算子：扫描\n",
    "ONNX 允许图值属性。这是定义（准）高阶操作符的机制，如 `If`, `Loop`, `Scan` 和 `SequenceMap`。虽然我们使用 Python 控制流来编码 `If` 和 `Loop`，但 ONNX Script 支持使用嵌套的 Python 函数来表示图值属性，如下例所示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxscript import graph, script\n",
    "from onnxscript import opset15 as op\n",
    "\n",
    "\n",
    "@script()\n",
    "def CumulativeSum(X):\n",
    "    @graph()\n",
    "    def Sum(sum_in, next):\n",
    "        sum_out = sum_in + next\n",
    "        return sum_out, sum_out\n",
    "\n",
    "    all_sum, cumulative_sum = op.Scan(0, X, body=Sum, num_scan_inputs=1)\n",
    "    return cumulative_sum"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在这种情况下，`Sum` 的函数定义被转换为一个图，并在调用 `Scan` 算子时用作属性值。\n",
    "\n",
    "作为图属性使用的函数定义必须满足一些约束。它们不能更新外部作用域的变量，但可以引用它们。（具体来说，这些函数不能使用全局或非局部声明。）它们还被限制不能使用与外部作用域变量同名的局部变量（不允许遮蔽）。\n",
    "\n",
    "在函数定义内部使用外部作用域变量时，还存在一个与SSA重命名的交互。以下代码是无效的，因为函数 `CumulativeSum` 引用了全局变量 `g`，而在函数定义和函数使用之间更新了 `g`。请注意，从 ONNX 的角度来看，对 `g` 的两次赋值代表了两个不同的张量 `g1` 和 `g2`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "from onnxscript import graph, script\n",
    "from onnxscript import opset15 as op\n",
    "\n",
    "try:\n",
    "\n",
    "    @script()\n",
    "    def CumulativeSum(X):\n",
    "        g = op.Constant(value=0)\n",
    "\n",
    "        @graph()\n",
    "        def Sum(sum_in, next):\n",
    "            sum_out = sum_in + next + g\n",
    "            return sum_out, sum_out\n",
    "\n",
    "        g = op.Constant(value=1)\n",
    "        all_sum, cumulative_sum = op.Scan(0, X, body=Sum, num_scan_inputs=1)\n",
    "        return cumulative_sum\n",
    "\n",
    "except Exception as e:\n",
    "    assert \"Outer scope variable\" in str(e)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
